// -*- mode:c++ -*-

// Copyright (c) 2015 RISC-V Foundation
// Copyright (c) 2016-2017 The University of Virginia
// Copyright (c) 2020 Barkhausen Institut
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met: redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer;
// redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution;
// neither the name of the copyright holders nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

////////////////////////////////////////////////////////////////////
//
// Integer instructions
//

def template ImmDeclare {{
    //
    // Static instruction class for "%(mnemonic)s".
    //
    // Declares the constructor, execute() and generateDisassembly()
    // overrides; the definitions are emitted by the matching constructor
    // and execute templates.
    //
    class %(class_name)s : public %(base_class)s
    {
      private:
        // Generated declaration of the register index array(s).
        %(reg_idx_arr_decl)s;

      public:
        /// Constructor; receives the ExtMachInst for this instruction.
        %(class_name)s(ExtMachInst machInst);

        /// Executes the instruction and returns the resulting fault.
        Fault
        execute(ExecContext *, trace::InstRecord *) const override;

        /// Returns the disassembly text for this instruction.
        std::string
        generateDisassembly(
                Addr pc, const loader::SymbolTable *symtab) const override;
    };
}};

def template ImmConstructor {{
    // Constructor definition for an instruction class carrying an
    // immediate. The mnemonic, the machine instruction and the op class
    // are forwarded to the base class.
    %(class_name)s::%(class_name)s(ExtMachInst machInst)
        : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s)
    {
        // Generated setup of the register index array(s).
        %(set_reg_idx_arr)s;
        // Generated constructor body.
        %(constructor)s;
        // Immediate decoding snippet supplied by the instruction format.
        %(imm_code)s;
    }
}};

def template ImmExecute {{
    // Executes the instruction: operand declarations, operand reads, the
    // instruction code and the operand write-back are all generated.
    Fault
    %(class_name)s::execute(
        ExecContext *xc, trace::InstRecord *traceData) const
    {
        %(op_decl)s;
        %(op_rd)s;
        %(code)s;
        %(op_wb)s;
        return NoFault;
    }

    // Disassembles as: mnemonic, a space, every listed register followed
    // by a comma and a space, and finally the immediate.
    std::string
    %(class_name)s::generateDisassembly(Addr pc,
            const loader::SymbolTable *symtab) const
    {
        std::stringstream out;
        out << mnemonic << ' ';

        const std::vector<RegId> operand_regs = {%(regs)s};
        for (const RegId &reg : operand_regs) {
            out << registerName(reg) << ", ";
        }

        out << imm;
        return out.str();
    }
}};

def template CILuiExecute {{
    // Executes the instruction: operand declarations, operand reads, the
    // instruction code and the operand write-back are all generated.
    Fault
    %(class_name)s::execute(
        ExecContext *xc, trace::InstRecord *traceData) const
    {
        %(op_decl)s;
        %(op_rd)s;
        %(code)s;
        %(op_wb)s;
        return NoFault;
    }

    // Disassembles as: mnemonic, a space, every listed register followed
    // by a comma and a space, and finally the upper-immediate field.
    std::string
    %(class_name)s::generateDisassembly(Addr pc,
            const loader::SymbolTable *symtab) const
    {
        std::stringstream out;
        out << mnemonic << ' ';

        const std::vector<RegId> operand_regs = {%(regs)s};
        for (const RegId &reg : operand_regs) {
            out << registerName(reg) << ", ";
        }

        // Only bits 31..12 of the immediate are shown, as an unsigned
        // 20-bit value. The original note on this code gives GCC
        // compatibility as the reason for this formatting.
        const uint64_t raw_imm = static_cast<uint64_t>(imm);
        const uint64_t upper_field = (raw_imm >> 12) & 0xFFFFF;
        out << upper_field;
        return out.str();
    }
}};

def template FenceExecute {{
    // Executes the instruction: operand declarations, operand reads, the
    // instruction code and the operand write-back are all generated.
    Fault
    %(class_name)s::execute(
        ExecContext *xc, trace::InstRecord *traceData) const
    {
        %(op_decl)s;
        %(op_rd)s;
        %(code)s;
        %(op_wb)s;
        return NoFault;
    }

    // Disassembles as the bare mnemonic, or, when FUNCT3 is zero, as the
    // mnemonic followed by the PRED and SUCC sets spelled with the
    // letters i, o, r and w.
    std::string
    %(class_name)s::generateDisassembly(Addr pc,
            const loader::SymbolTable *symtab) const
    {
        std::stringstream out;
        out << mnemonic;

        if (FUNCT3 != 0)
            return out.str();

        // Emits one letter per set bit, from bit 3 (i) down to bit 0 (w).
        auto append_access_set = [&out](uint64_t field) {
            if (field & 0x8)
                out << 'i';
            if (field & 0x4)
                out << 'o';
            if (field & 0x2)
                out << 'r';
            if (field & 0x1)
                out << 'w';
        };

        out << ' ';
        append_access_set(PRED);
        out << ", ";
        append_access_set(SUCC);
        return out.str();
    }
}};

def template BranchDeclare {{
    //
    // Static instruction class for "%(mnemonic)s".
    //
    // Besides execute() and generateDisassembly(), this class overrides
    // the branchTarget() overload that takes the pc state of the branch.
    //
    class %(class_name)s : public %(base_class)s
    {
      private:
        // Generated declaration of the register index array(s).
        %(reg_idx_arr_decl)s;

      public:
        /// Constructor; receives the ExtMachInst for this instruction.
        %(class_name)s(ExtMachInst machInst);

        /// Executes the instruction and returns the resulting fault.
        Fault
        execute(ExecContext *, trace::InstRecord *) const override;

        /// Returns the disassembly text for this instruction.
        std::string
        generateDisassembly(Addr pc,
                const loader::SymbolTable *symtab) const override;

        /// Returns the target pc state computed from the branch pc state.
        std::unique_ptr<PCStateBase>
        branchTarget(const PCStateBase &branch_pc) const override;

        // Keep the remaining StaticInst overloads of branchTarget visible.
        using StaticInst::branchTarget;
    };
}};

def template BranchExecute {{
    // Executes the instruction: operand declarations, operand reads, the
    // instruction code and the operand write-back are all generated.
    Fault
    %(class_name)s::execute(ExecContext *xc,
        trace::InstRecord *traceData) const
    {
        %(op_decl)s;
        %(op_rd)s;
        %(code)s;
        %(op_wb)s;
        return NoFault;
    }

    // Returns a copy of the branch pc state whose pc has been set to
    // rvSext(branch pc + imm).
    std::unique_ptr<PCStateBase>
    %(class_name)s::branchTarget(const PCStateBase &branch_pc) const
    {
        auto &branch_state = branch_pc.as<RiscvISA::PCState>();
        std::unique_ptr<PCState> target_state(
            dynamic_cast<PCState*>(branch_state.clone()));
        const auto target_addr = rvSext(branch_state.pc() + imm);
        target_state->set(target_addr);
        return target_state;
    }

    // Disassembles as: mnemonic, a space, every listed register followed
    // by a comma and a space, and finally the immediate.
    std::string
    %(class_name)s::generateDisassembly(
            Addr pc, const loader::SymbolTable *symtab) const
    {
        std::stringstream out;
        out << mnemonic << ' ';

        const std::vector<RegId> operand_regs = {%(regs)s};
        for (const RegId &reg : operand_regs) {
            out << registerName(reg) << ", ";
        }

        out << imm;
        return out.str();
    }
}};

def template JumpDeclare {{
    //
    // Static instruction class for "%(mnemonic)s".
    //
    // Besides execute() and generateDisassembly(), this class overrides
    // the branchTarget() overload that takes a ThreadContext.
    //
    class %(class_name)s : public %(base_class)s
    {
      private:
        // Generated declaration of the register index array(s).
        %(reg_idx_arr_decl)s;

      public:
        /// Constructor; receives the ExtMachInst for this instruction.
        %(class_name)s(ExtMachInst machInst);

        /// Executes the instruction and returns the resulting fault.
        Fault
        execute(ExecContext *, trace::InstRecord *) const override;

        /// Returns the disassembly text for this instruction.
        std::string
        generateDisassembly(Addr pc,
                const loader::SymbolTable *symtab) const override;

        /// Returns the target pc state computed from thread context state.
        std::unique_ptr<PCStateBase>
        branchTarget(ThreadContext *tc) const override;

        // Keep the remaining StaticInst overloads of branchTarget visible.
        using StaticInst::branchTarget;
    };
}};

def template JumpConstructor {{
    // Constructor definition for jump instruction classes. After the
    // generated setup it derives the IsCall and IsReturn flags from the
    // encoding. Flags are only ever set here, never cleared.
    %(class_name)s::%(class_name)s(ExtMachInst machInst)
        : %(base_class)s("%(mnemonic)s", machInst, %(op_class)s)
    {
        %(set_reg_idx_arr)s;
        %(constructor)s;
        %(imm_code)s;

        if (QUADRANT != 0x3) {
            // Encodings outside quadrant 3: only COPCODE 4 is examined.
            if (COPCODE == 4) {
                // "c_jr" (CFUNCT1 == 0): a return when RC1 is 1 or 5.
                const bool c_jr_return =
                    (CFUNCT1 == 0) && (RC1 == 1 || RC1 == 5);
                // "c_jalr" (CFUNCT1 == 1): a return only when RC1 is 5.
                const bool c_jalr_return = (CFUNCT1 == 1) && (RC1 == 5);

                if (c_jr_return || c_jalr_return)
                    flags[IsReturn] = true;
            }
        } else {
            // Registers 1 and 5 are treated as link registers.
            const bool rd_link = (RD == 1 || RD == 5);
            const bool rs1_link = (RS1 == 1 || RS1 == 5);

            // "jal" and "jalr": a call (push RAS) whenever RD is a link
            // register, independent of RS1.
            if (rd_link)
                flags[IsCall] = true;

            // "jalr" only: a return (pop RAS) when RS1 is a link register
            // and either RD is not a link register, or RD is a link
            // register different from RS1 (pop, then push). When both are
            // the same link register only the push above applies.
            const bool is_jalr = (FUNCT3 == 0x0 && OPCODE5 == 0x19);
            if (is_jalr && rs1_link && (!rd_link || RS1 != RD))
                flags[IsReturn] = true;
        }
    }
}};

def template JumpExecute {{
    // Executes the instruction: operand declarations, operand reads, the
    // instruction code and the operand write-back are all generated.
    Fault
    %(class_name)s::execute(
        ExecContext *xc, trace::InstRecord *traceData) const
    {
        %(op_decl)s;
        %(op_rd)s;
        %(code)s;
        %(op_wb)s;
        return NoFault;
    }

    // Returns a copy of the thread's pc state whose pc has been set to
    // rvSext of (source register 0 + imm) with bit 0 cleared.
    std::unique_ptr<PCStateBase>
    %(class_name)s::branchTarget(ThreadContext *tc) const
    {
        // Take ownership of the clone right away.
        std::unique_ptr<PCStateBase> target_state{tc->pcState().clone()};

        const auto base_plus_imm = tc->getReg(srcRegIdx(0)) + imm;
        target_state->as<PCState>().set(rvSext(base_plus_imm & ~0x1));
        return target_state;
    }

    // Disassembles as: mnemonic, destination register 0, then the
    // immediate with source register 0 in parentheses.
    std::string
    %(class_name)s::generateDisassembly(
            Addr pc, const loader::SymbolTable *symtab) const
    {
        std::stringstream out;
        out << mnemonic << ' ';
        out << registerName(destRegIdx(0)) << ", ";
        out << imm;
        out << "(" << registerName(srcRegIdx(0)) << ")";
        return out.str();
    }
}};

def template CSRExecute {{
    // Executes a CSR access instruction.
    //
    // Order of operations:
    //  1. Look up the CSR and reject unknown indices, CSRs not available
    //     for the current rv_type, CSRs whose ISA extensions are not all
    //     set in MISA, and CSRs that need Smrnmi while it is disabled.
    //  2. Check the current privilege mode against bits 9..8 of the CSR
    //     index.
    //  3. Optionally remap the access to a virtual CSR, then run the TVM
    //     checks.
    //  4. If the read condition holds, read the CSR into "data".
    //  5. Run the instruction code.
    //  6. If the write condition holds, reject read-only CSRs and write
    //     "data" to the CSR.
    //
    // Every failed check returns a fault; success returns NoFault.
    Fault
    %(class_name)s::execute(ExecContext *xc,
        trace::InstRecord *traceData) const
    {
        // We assume a riscv instruction is always run with a riscv ISA.
        auto isa = static_cast<RiscvISA::ISA*>(xc->tcBase()->getIsaPtr());
        auto& csr_data = isa->getCSRDataMap();
        MISA misa = isa->readMiscRegNoEffect(MISCREG_ISA);

        auto csr_data_it = csr_data.find(csr);
        if (csr_data_it == csr_data.end()) {
            return std::make_shared<IllegalInstFault>(
                    csprintf("Illegal CSR index %#x\n", csr), machInst);
        }

        // midx and csrName are local copies on purpose: swapToVirtCSR
        // below may modify them.
        RegIndex midx = csr_data_it->second.physIndex;
        std::string csrName = csr_data_it->second.name;
        if ((csr_data_it->second.rvTypes & (1 << machInst.rv_type)) == 0) {
            return std::make_shared<IllegalInstFault>(
                    csprintf("%s is not support in mode %d\n",
                             csrName,
                             machInst.rv_type),
                    machInst);
        }

        // Every extension bit required by the CSR must be set in MISA.
        MISA csr_exts = csr_data_it->second.isaExts;
        if ((csr_exts & misa) != csr_exts) {
            // The format string takes exactly one argument (the name).
            return std::make_shared<IllegalInstFault>(
                    csprintf("%s is not support in the isa spec\n",
                             csrName),
                    machInst);
        }

        if (csr_data_it->second.requireSmrnmi && !isa->enableSmrnmi()) {
            return std::make_shared<IllegalInstFault>(
                    csprintf("%s requires Smrnmi enabled\n",
                             csrName),
                    machInst);
        }

        %(op_decl)s;
        %(op_rd)s;

        RegVal data = 0;
        // Here "csr" is still the class member; it is shadowed further
        // down.
        auto lowestAllowedMode = (PrivilegeMode)bits(csr, 9, 8);
        auto pm = (PrivilegeMode)xc->readMiscReg(MISCREG_PRV);

        // H-extension quirk
        // Hypervisor CSRs have a lowestAllowedMode of 2
        // However, MISCREG_PRV contains PRV_S = 1 and the V-bit is off
        // when we are in HS-mode.

        // To allow HS-mode to access its CSRs we upgrade the
        // effective privilege to PRV_HS = 2.
        // Note that PRV_HS should NEVER be written back to
        // MISCREG_PRV as that might end up in mpp
        // which would violate the specification.

        if (misa.rvh && pm == PRV_S && !virtualizationEnabled(xc)) {
            pm = PRV_HS;
        }
        if (pm < lowestAllowedMode) {
            // Insufficient privilege: a virtual instruction fault when
            // virtualization is enabled, an illegal instruction fault
            // otherwise.
            if (virtualizationEnabled(xc)) {
                return std::make_shared<VirtualInstFault>(
                    csprintf("[V] %s is not accessible in %s\n", csrName, pm),
                    machInst);
            }
            else {
                return std::make_shared<IllegalInstFault>(
                    csprintf("%s is not accessible in %s\n", csrName, pm),
                    machInst);
            }
        }

        // trick to get rid of const
        // so we can do swap if needed
        // this will shadow class member
        uint64_t csr = this->csr;

        // In virtualized supervisor mode, unmodified supervisor
        // software thinks it is accessing S-mode CSRs!
        // However, we remap these accesses to the virtual CSRs
        // hosted in different MISCREGs to implement the virtualization.
        if (misa.rvh && virtualizationEnabled(xc) && pm == PRV_S) {
            // Check if the supervisor CSR being accessed
            // is virtualized and remap it.
            // Args are passed by reference and modified if needed.
            isa->swapToVirtCSR(csr, midx, csrName);
        }

        // Checks tvm fields for ATP regs
        Fault tvm_fault = isa->tvmChecks(csr, pm, machInst);
        if (tvm_fault != NoFault) { return tvm_fault; }

        if (%(read_cond)s) {
            // Counter CSRs get an additional access check before the read.
            switch (csr) {
              case CSR_CYCLE:   case CSR_CYCLEH:
              case CSR_TIME:    case CSR_TIMEH:
              case CSR_INSTRET: case CSR_INSTRETH:
              case CSR_HPMCOUNTER03 ... CSR_HPMCOUNTER31:
              case CSR_HPMCOUNTER03H ... CSR_HPMCOUNTER31H:
                {
                    Fault hpm_fault = isa->hpmCounterCheck(midx, machInst);
                    if (hpm_fault != NoFault)
                        return hpm_fault;
                    break;
                }
              default:
                break;
            }
            data = rvZext(isa->readCSR(xc, csr));
        }

        %(code)s;

        if (%(write_cond)s) {
            // CSR indices with bits 11..10 equal to 0x3 are read-only.
            if (bits(csr, 11, 10) == 0x3) {
                return std::make_shared<IllegalInstFault>(
                        csprintf("CSR %s is read-only\n", csrName), machInst);
            }
            DPRINTF(RiscvMisc, "Writing %#x to CSR %s.\n",
                    data, csrName);
            isa->writeCSR(xc, csr, data);
        }
        %(op_wb)s;
        return NoFault;
    }
}};

def format ROp(code, *opt_flags) {{
    # Instruction format built on the 'RegOp' base class; all four outputs
    # come from the Basic* templates.
    inst_params = InstObjParams(name, Name, 'RegOp', code, opt_flags)

    header_output = BasicDeclare.subst(inst_params)
    decoder_output = BasicConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = BasicExecute.subst(inst_params)
}};

def format IOp(code, imm_type='int64_t', imm_code='imm = sext<12>(IMM12);',
              *opt_flags) {{
    # Instruction format with an immediate. The base class is ImmOp
    # instantiated with imm_type, and imm_code is the snippet that decodes
    # the immediate in the constructor. The disassembly lists destination
    # register 0 followed by source register 0.
    base_class = 'ImmOp<%s>' % imm_type
    reg_list = ','.join(['destRegIdx(0)', 'srcRegIdx(0)'])
    # Key order is kept as-is on purpose.
    subst_values = {'imm_code': imm_code, 'code': code, 'regs': reg_list}
    inst_params = InstObjParams(name, Name, base_class, subst_values,
                                opt_flags)

    header_output = ImmDeclare.subst(inst_params)
    decoder_output = ImmConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = ImmExecute.subst(inst_params)
}};

def format FenceOp(code, imm_type='int64_t', *opt_flags) {{
    # Instruction format that shares the Imm declaration and constructor
    # templates but uses FenceExecute for execution and disassembly. The
    # immediate is always the sign-extended IMM12 field.
    base_class = 'ImmOp<%s>' % imm_type
    reg_list = ','.join(['destRegIdx(0)', 'srcRegIdx(0)'])
    # Key order is kept as-is on purpose.
    subst_values = {'code': code, 'imm_code': 'imm = sext<12>(IMM12);',
                    'regs': reg_list}
    inst_params = InstObjParams(name, Name, base_class, subst_values,
                                opt_flags)

    header_output = ImmDeclare.subst(inst_params)
    decoder_output = ImmConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = FenceExecute.subst(inst_params)
}};

def format BOp(code, *opt_flags) {{
    # Instruction format using the Branch declaration and execute
    # templates. The immediate is assembled from four encoding fields
    # (bit 0 is left clear) and sign-extended from 13 bits. The disassembly
    # lists source registers 0 and 1.
    imm_code = """
                imm = BIMM12BITS4TO1 << 1  |
                      BIMM12BITS10TO5 << 5 |
                      BIMM12BIT11 << 11    |
                      IMMSIGN << 12;
                imm = sext<13>(imm);
               """
    reg_list = ','.join(['srcRegIdx(0)', 'srcRegIdx(1)'])
    # Key order is kept as-is on purpose.
    subst_values = {'code': code, 'imm_code': imm_code, 'regs': reg_list}
    inst_params = InstObjParams(name, Name, 'ImmOp<int64_t>', subst_values,
                                opt_flags)

    header_output = BranchDeclare.subst(inst_params)
    decoder_output = ImmConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = BranchExecute.subst(inst_params)
}};

def format Jump(code, *opt_flags) {{
    # Instruction format using the Jump declaration, constructor and
    # execute templates. The immediate is the sign-extended IMM12 field.
    reg_list = ','.join(['srcRegIdx(0)'])
    # Key order is kept as-is on purpose.
    subst_values = {'code': code, 'imm_code': 'imm = sext<12>(IMM12);',
                    'regs': reg_list}
    inst_params = InstObjParams(name, Name, 'ImmOp<int64_t>', subst_values,
                                opt_flags)

    header_output = JumpDeclare.subst(inst_params)
    decoder_output = JumpConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = JumpExecute.subst(inst_params)
}};

def format UOp(code, *opt_flags) {{
    # Instruction format using the Imm templates. The immediate is taken
    # from IMM20 unchanged, and the disassembly lists destination
    # register 0 only.
    reg_list = ','.join(['destRegIdx(0)'])
    # Key order is kept as-is on purpose.
    subst_values = {'code': code, 'imm_code': 'imm = IMM20;',
                    'regs': reg_list}
    inst_params = InstObjParams(name, Name, 'ImmOp<int64_t>', subst_values,
                                opt_flags)

    header_output = ImmDeclare.subst(inst_params)
    decoder_output = ImmConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = ImmExecute.subst(inst_params)
}};

def format JOp(code, *opt_flags) {{
    # Instruction format combining the Branch declaration and execute
    # templates with the Jump constructor template. The immediate is
    # assembled from four encoding fields (bit 0 is left clear) and
    # sign-extended from 21 bits. The disassembly lists destination
    # register 0 only.
    imm_code = """
                imm = UJIMMBITS10TO1 << 1   |
                      UJIMMBIT11 << 11      |
                      UJIMMBITS19TO12 << 12 |
                      IMMSIGN << 20;
                imm = sext<21>(imm);
               """
    regs = ['destRegIdx(0)']
    iop = InstObjParams(name, Name, 'ImmOp<int64_t>',
        {'code': code, 'imm_code': imm_code,
         'regs': ','.join(regs)}, opt_flags)
    header_output = BranchDeclare.subst(iop)
    decoder_output = JumpConstructor.subst(iop)
    decode_block = BasicDecode.subst(iop)
    exec_output = BranchExecute.subst(iop)
}};

def format SystemOp(code, *opt_flags) {{
    # Instruction format built on the 'SystemOp' base class; all four
    # outputs come from the Basic* templates.
    inst_params = InstObjParams(name, Name, 'SystemOp', code, opt_flags)

    header_output = BasicDeclare.subst(inst_params)
    decoder_output = BasicConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = BasicExecute.subst(inst_params)
}};

def format CSROp(code, read_cond, write_cond, *opt_flags) {{
    # Instruction format built on the 'CSROp' base class. read_cond and
    # write_cond are substituted into CSRExecute as the conditions that
    # guard the CSR read and the CSR write.
    # Key order is kept as-is on purpose.
    subst_values = {'code': code, 'read_cond': read_cond,
                    'write_cond': write_cond}
    inst_params = InstObjParams(name, Name, 'CSROp', subst_values,
                                opt_flags)

    header_output = BasicDeclare.subst(inst_params)
    decoder_output = BasicConstructor.subst(inst_params)
    decode_block = BasicDecode.subst(inst_params)
    exec_output = CSRExecute.subst(inst_params)
}};
